import torch
from torch.utils.data import Dataset


class dataset(Dataset):

    def __init__(self, numOfTriple):
        self.tripleList = torch.LongTensor(range(numOfTriple))
        self.numOfTriple = numOfTriple

    def __len__(self):
        return self.numOfTriple

    def __getitem__(self, item):
        return self.tripleList[item]


class generateBatches:

    def __init__(self, batch, train2id, positiveBatch, corruptedBatch, numOfEntity, headRelation2Tail, tailRelation2Head):
        self.batch = batch
        self.train2id = train2id
        self.positiveBatch = positiveBatch
        self.corruptedBatch = corruptedBatch
        self.numOfEntity = numOfEntity
        self.headRelation2Tail = headRelation2Tail
        self.tailRelation2Head = tailRelation2Head

        self.generatePosAndCorBatch()

    def generatePosAndCorBatch(self):
        self.positiveBatch["h"] = []
        self.positiveBatch["r"] = []
        self.positiveBatch["t"] = []
        self.corruptedBatch["h"] = []
        self.corruptedBatch["r"] = []
        self.corruptedBatch["t"] = []
        for tripleId in self.batch:
            tmpHead = self.train2id["h"][tripleId]
            tmpRelation = self.train2id["r"][tripleId]
            tmpTail = self.train2id["t"][tripleId]
            self.positiveBatch["h"].append(tmpHead)
            self.positiveBatch["r"].append(tmpRelation)
            self.positiveBatch["t"].append(tmpTail)
            if torch.rand(1).item() >= 0.5:
                tmpCorruptedHead = torch.FloatTensor(1).uniform_(0, self.numOfEntity).long().item()
                while tmpCorruptedHead in self.tailRelation2Head[tmpTail][tmpRelation] or tmpCorruptedHead == tmpHead:
                    tmpCorruptedHead = torch.FloatTensor(1).uniform_(0, self.numOfEntity).long().item()
                tmpHead = tmpCorruptedHead
            else:
                tmpCorruptedTail = torch.FloatTensor(1).uniform_(0, self.numOfEntity).long().item()
                while tmpCorruptedTail in self.headRelation2Tail[tmpHead][tmpRelation] or tmpCorruptedTail == tmpTail:
                    tmpCorruptedTail = torch.FloatTensor(1).uniform_(0, self.numOfEntity).long().item()
                tmpTail = tmpCorruptedTail
            self.corruptedBatch["h"].append(tmpHead)
            self.corruptedBatch["r"].append(tmpRelation)
            self.corruptedBatch["t"].append(tmpTail)
        for aKey in self.positiveBatch:
            self.positiveBatch[aKey] = torch.LongTensor(self.positiveBatch[aKey])
        for aKey in self.corruptedBatch:
            self.corruptedBatch[aKey] = torch.LongTensor(self.corruptedBatch[aKey])
